package com.tpvlog.dfs.client;

import com.tpvlog.dfs.client.req.Host;
import com.tpvlog.dfs.client.req.NetworkRequest;
import com.tpvlog.dfs.client.resp.NetworkResponse;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

/**
 * Client-side NIO network connection manager.
 *
 * <p>Usage: call {@link #tryConnect(Host)} to open a connection, then
 * {@link #sendRequest(NetworkRequest)} to queue requests for that host, and
 * {@link #waitResponse(String)} to collect the response of requests that need one (other requests
 * are completed through their callback, if any).
 *
 * <p>All socket I/O happens on a single background poll thread, which keeps at most one request in
 * flight per host. A second background thread fails requests that stay in flight longer than
 * {@link #REQUEST_TIMEOUT} milliseconds.
 */
public class NetworkManager {
    /** Connection state: a connection is being established. */
    public static final Integer CONNECTING = 1;
    /** Connection state: connected. */
    public static final Integer CONNECTED = 2;
    /** Connection state: not connected (connection failed or was dropped). */
    public static final Integer DISCONNECTED = 3;

    /** Selector poll timeout in milliseconds. */
    public static final Long POLL_TIMEOUT = 500L;
    /** Maximum time in milliseconds a request may stay in flight before it is failed (30 s). */
    public static final long REQUEST_TIMEOUT = 30 * 1000;
    /** Interval in milliseconds between request-timeout sweeps (1 s). */
    public static final long REQUEST_TIMEOUT_CHECK_INTERVAL = 1000;

    // Selector shared by all connections; select() is only called from the poll thread.
    private final Selector selector;

    // Hosts queued by tryConnect(Host), waiting for the poll thread to open a connection.
    private final ConcurrentLinkedQueue<Host> waitingConnectHosts;

    // Connection state of each host: CONNECTING, CONNECTED or DISCONNECTED.
    private final Map<Host, Integer> connectState;
    // Selection key (and therefore channel) of each connected host.
    private final Map<Host, SelectionKey> connections;

    // Per-host queues of requests waiting to be dispatched.
    private final Map<Host, ConcurrentLinkedQueue<NetworkRequest>> waitingRequests;
    // The request currently in flight for each host (at most one per host).
    private final Map<Host, NetworkRequest> toSendRequests;

    // Completed responses not yet collected by waitResponse(), keyed by request id.
    private final Map<String, NetworkResponse> finishedResponses;
    // Responses that have only been partially read so far, keyed by host.
    private final Map<Host, NetworkResponse> unfinishedResponses;

    /**
     * Creates the manager and starts its poll thread and request-timeout thread.
     *
     * @throws IllegalStateException if the NIO selector cannot be opened
     */
    public NetworkManager() {
        try {
            this.selector = Selector.open();
        } catch (IOException e) {
            // Without a selector the poll thread could do nothing but fail; refuse to construct.
            throw new IllegalStateException("Failed to open NIO selector", e);
        }

        this.connections = new ConcurrentHashMap<Host, SelectionKey>();
        this.connectState = new ConcurrentHashMap<Host, Integer>();
        this.waitingConnectHosts = new ConcurrentLinkedQueue<Host>();
        this.waitingRequests = new ConcurrentHashMap<Host, ConcurrentLinkedQueue<NetworkRequest>>();
        this.toSendRequests = new ConcurrentHashMap<Host, NetworkRequest>();
        this.finishedResponses = new ConcurrentHashMap<String, NetworkResponse>();
        this.unfinishedResponses = new ConcurrentHashMap<Host, NetworkResponse>();

        new NetworkPollThread().start();
        new RequestTimeoutCheckThread().start();
    }

    /**
     * Connects to {@code host} unless it is already connected, blocking until the attempt ends.
     *
     * @param host the host to connect to
     * @return {@code true} if the host is connected; {@code false} if the connection attempt failed
     *     or the calling thread was interrupted while waiting
     */
    public Boolean tryConnect(Host host) {
        assert host != null;
        // Lock so that concurrent callers enqueue at most one connection attempt per host.
        synchronized (this) {
            Integer state = connectState.get(host);
            if (state == null || state.equals(DISCONNECTED)) {
                connectState.put(host, CONNECTING);
                // The poll thread dequeues the host and opens the connection asynchronously.
                waitingConnectHosts.offer(host);
                // Break out of a pending select() so the attempt starts without waiting out POLL_TIMEOUT.
                selector.wakeup();
            }
            // Poll until the poll thread moves the host out of CONNECTING; wait() releases the lock.
            while (CONNECTING.equals(connectState.get(host))) {
                try {
                    wait(100);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return false;
                }
            }
            return CONNECTED.equals(connectState.get(host));
        }
    }

    /**
     * Queues {@code request} for the host it is addressed to. The poll thread sends it once no
     * other request is in flight for that host.
     *
     * @param request the request to send
     * @throws IllegalStateException if no connection to the request's host has ever been established
     */
    public void sendRequest(NetworkRequest request) {
        // 1. Determine which host the request goes to.
        Host host = new Host(request.getHostname(), request.getIp(), request.getNioPort());
        ConcurrentLinkedQueue<NetworkRequest> requestQueue = waitingRequests.get(host);
        if (requestQueue == null) {
            // The per-host queue is created when a connection is established; fail fast and clearly.
            throw new IllegalStateException("Not connected to " + host + "; call tryConnect(host) first");
        }
        // 2. Queue the request for that host and let the poll thread pick it up promptly.
        requestQueue.offer(request);
        selector.wakeup();
    }

    /**
     * Blocks until the response to the request with id {@code requestId} is available, then
     * frees that request's host so its next queued request can be sent.
     *
     * @param requestId id of a request sent with "need response" set
     * @return the response (possibly an error response, e.g. after a timeout)
     * @throws InterruptedException if interrupted while waiting
     */
    public NetworkResponse waitResponse(String requestId) throws Exception {
        NetworkResponse response;
        while ((response = finishedResponses.get(requestId)) == null) {
            Thread.sleep(100);
        }
        // Free the host's in-flight slot. Match on the request id instead of rebuilding the Host
        // from the response: not every response carries the host's ip and port.
        for (Map.Entry<Host, NetworkRequest> entry : toSendRequests.entrySet()) {
            if (requestId.equals(entry.getValue().getId())) {
                toSendRequests.remove(entry.getKey(), entry.getValue());
            }
        }
        finishedResponses.remove(requestId);
        // Let the poll thread dispatch the host's next request without waiting out POLL_TIMEOUT.
        selector.wakeup();

        return response;
    }

    /**
     * Core network thread: opens connections, dispatches requests and performs all socket I/O.
     */
    class NetworkPollThread extends Thread {
        @Override
        public void run() {
            while (true) {
                try {
                    // 1. Open connections for queued hosts (registers OP_CONNECT).
                    tryConnect();
                    // 2. For connected hosts, dispatch one queued request (registers OP_WRITE).
                    prepareRequests();
                    // 3. Handle the events of all registered channels.
                    poll();
                } catch (RuntimeException e) {
                    // This is the only I/O thread; one bad iteration must not kill it.
                    e.printStackTrace();
                }
            }
        }

        /**
         * Opens a non-blocking connection for every host queued by {@code tryConnect(Host)}.
         */
        private void tryConnect() {
            Host host;
            while ((host = waitingConnectHosts.poll()) != null) {
                if (CONNECTED.equals(connectState.get(host))) {
                    continue;
                }
                SocketChannel channel = null;
                try {
                    channel = SocketChannel.open();
                    channel.configureBlocking(false);
                    boolean connected =
                            channel.connect(new InetSocketAddress(host.getHostname(), host.getNioPort()));
                    // Attach the caller's Host so every later event maps back to exactly the key
                    // tryConnect(Host) is waiting on (re-deriving it from the socket may differ).
                    SelectionKey key = channel.register(selector, SelectionKey.OP_CONNECT, host);
                    if (connected) {
                        // A non-blocking connect can complete immediately; no OP_CONNECT event follows.
                        onConnected(key, host);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    closeQuietly(channel);
                    connectState.put(host, DISCONNECTED);
                }
            }
        }

        /**
         * Moves one queued request per connected, idle host into the in-flight slot and registers
         * OP_WRITE so that {@link #poll()} sends it.
         */
        private void prepareRequests() {
            for (Map.Entry<Host, ConcurrentLinkedQueue<NetworkRequest>> entry : waitingRequests.entrySet()) {
                Host host = entry.getKey();
                ConcurrentLinkedQueue<NetworkRequest> requestQueue = entry.getValue();
                if (requestQueue.isEmpty() || toSendRequests.containsKey(host)) {
                    continue;
                }
                SelectionKey key = connections.get(host);
                // Only dispatch over a live connection; otherwise requests stay queued until reconnect.
                if (key == null || !key.isValid() || !CONNECTED.equals(connectState.get(host))) {
                    continue;
                }
                NetworkRequest request = requestQueue.poll();
                if (request == null) {
                    continue;
                }
                // Start the timeout clock at dispatch; it is restarted once the write completes.
                request.setSendTime(System.currentTimeMillis());
                toSendRequests.put(host, request);
                key.interestOps(SelectionKey.OP_WRITE);
            }
        }

        /**
         * Waits up to {@link #POLL_TIMEOUT} ms for channel events and handles them: connection
         * completion, request writes and response reads. A failure on one channel drops that
         * connection only.
         */
        private void poll() {
            try {
                if (selector.select(POLL_TIMEOUT) <= 0) {
                    return;
                }
            } catch (IOException e) {
                e.printStackTrace();
                return;
            }

            Iterator<SelectionKey> keysIterator = selector.selectedKeys().iterator();
            while (keysIterator.hasNext()) {
                SelectionKey key = keysIterator.next();
                keysIterator.remove();

                SocketChannel channel = (SocketChannel) key.channel();
                try {
                    if (!key.isValid()) {
                        continue;
                    }
                    if (key.isConnectable()) {
                        // 1. Connection completed (or failed).
                        finishConnect(key, channel);
                    } else if (key.isWritable()) {
                        // 2. Send the in-flight request.
                        sendRequest(key, channel);
                    } else if (key.isReadable()) {
                        // 3. Read the response.
                        readResponse(key, channel);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    closeConnection(key, channel);
                }
            }
        }

        /**
         * Completes a pending non-blocking connect.
         *
         * @throws IOException if the connection attempt failed
         */
        private void finishConnect(SelectionKey key, SocketChannel channel) throws IOException {
            // finishConnect() returns false while the handshake is still in progress; wait for the
            // next OP_CONNECT instead of sleeping on the only I/O thread.
            if (channel.isConnectionPending() && !channel.finishConnect()) {
                return;
            }
            onConnected(key, hostOf(key, channel));
        }

        /**
         * Records {@code host} as connected through {@code key} and prepares its request queue.
         */
        private void onConnected(SelectionKey key, Host host) {
            System.out.println("与远程服务器" + host + "成功建立连接......");
            // Drop OP_CONNECT: keeping it registered on a connected socket can make select()
            // return immediately on every call.
            key.interestOps(0);

            waitingRequests.putIfAbsent(host, new ConcurrentLinkedQueue<NetworkRequest>());
            // Associate the host with its connection.
            connections.put(host, key);
            // Publish the new state last; tryConnect(Host) is polling it.
            connectState.put(host, CONNECTED);
        }

        /**
         * Writes the in-flight request of the key's host and switches to OP_READ once it is fully
         * sent.
         *
         * @throws IOException if the write fails
         */
        private void sendRequest(SelectionKey key, SocketChannel channel) throws IOException {
            // 1. Resolve the host.
            Host host = hostOf(key, channel);

            // 2. The request in flight for this host.
            NetworkRequest request = toSendRequests.get(host);
            if (request == null) {
                // Nothing to write (e.g. the request already timed out): stop watching OP_WRITE.
                key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
                return;
            }

            // 3. A non-blocking write may be partial. Keep OP_WRITE and resume on the next writable
            // event instead of spinning here.
            ByteBuffer buffer = request.getBuffer();
            channel.write(buffer);
            if (buffer.hasRemaining()) {
                return;
            }
            request.setSendTime(System.currentTimeMillis());
            System.out.println("本次向" + host + "机器的请求发送完毕......");

            // 4. Wait for the server's response.
            key.interestOps(SelectionKey.OP_READ);
        }

        /**
         * Reads (part of) the response to the in-flight request of the key's host and delivers it
         * once complete.
         *
         * @throws Exception if reading fails, the peer closed the connection, or the data cannot be
         *     attributed to a request
         */
        private void readResponse(SelectionKey key, SocketChannel channel) throws Exception {
            Host host = hostOf(key, channel);

            // 1. The request whose response is being read.
            NetworkRequest request = toSendRequests.get(host);
            if (request == null) {
                // Data with no request in flight (e.g. after a timeout) cannot be attributed.
                throw new IOException("Unexpected data from " + host + ": no request in flight");
            }

            // 2. Parse the response according to the request type.
            NetworkResponse response;
            if (request.getRequestType().equals(NetworkRequest.REQUEST_SEND_FILE)) {
                response = getSendFileResponse(request.getId(), host, channel);
            } else if (request.getRequestType().equals(NetworkRequest.REQUEST_READ_FILE)) {
                response = getReadFileResponse(request.getId(), host, channel);
            } else {
                throw new IOException("Unsupported request type " + request.getRequestType() + " for " + host);
            }
            // The response arrived split across reads: keep OP_READ and continue on the next event.
            if (Boolean.FALSE.equals(response.getFinished())) {
                return;
            }

            // 3. Response complete: stop watching OP_READ.
            key.interestOps(key.interestOps() & ~SelectionKey.OP_READ);

            // 4. Hand the response over.
            completeRequest(host, request, response);
        }

        /**
         * Tears down a failed connection. Marks its host DISCONNECTED so that
         * {@code tryConnect(Host)} can reconnect, and fails the host's in-flight request.
         */
        private void closeConnection(SelectionKey key, SocketChannel channel) {
            key.cancel();
            closeQuietly(channel);

            Object attachment = key.attachment();
            if (!(attachment instanceof Host)) {
                return;
            }
            Host host = (Host) attachment;
            connections.remove(host, key);
            unfinishedResponses.remove(host);
            connectState.put(host, DISCONNECTED);

            NetworkRequest inFlight = toSendRequests.get(host);
            if (inFlight != null) {
                failRequest(host, inFlight);
            }
        }
    }

    /**
     * Reads the response to a file-upload request: a single read of up to 1024 bytes.
     *
     * @throws IOException if the read fails or the peer has closed the connection
     */
    private NetworkResponse getSendFileResponse(String requestId, Host host, SocketChannel channel) throws Exception {
        ByteBuffer buffer = ByteBuffer.allocate(1024);
        readOrFail(channel, buffer, host);
        buffer.flip();

        NetworkResponse response = new NetworkResponse();
        response.setRequestId(requestId);
        response.setHostname(host.getHostname());
        response.setIp(host.getIp());
        response.setBuffer(buffer);
        response.setError(false);
        response.setFinished(true);
        return response;
    }

    /**
     * Reads (part of) the response to a file-download request. The response is a length header of
     * {@code NetworkRequest.FILE_LENGTH} bytes, read with {@code getLong()}, followed by that many
     * bytes of file content.
     *
     * <p>A partially read response is kept in {@code unfinishedResponses} and resumed on the next
     * call for the same host; the returned response then has {@code finished == false}.
     *
     * @throws IOException if the read fails, the peer closes the connection mid-response, or the
     *     announced length is negative or does not fit in an int
     */
    private NetworkResponse getReadFileResponse(String requestId, Host host, SocketChannel channel) throws Exception {
        NetworkResponse response = unfinishedResponses.get(host);
        if (response == null) {
            response = new NetworkResponse();
            response.setRequestId(requestId);
            response.setHostname(host.getHostname());
            response.setIp(host.getIp());
            response.setError(false);
            response.setFinished(false);
        }

        // Phase 1: read the length header (the body buffer does not exist yet).
        if (response.getBuffer() == null) {
            ByteBuffer lengthBuffer = response.getLengthBuffer();
            if (lengthBuffer == null) {
                lengthBuffer = ByteBuffer.allocate(NetworkRequest.FILE_LENGTH);
                response.setLengthBuffer(lengthBuffer);
            }
            readOrFail(channel, lengthBuffer, host);
            if (lengthBuffer.hasRemaining()) {
                unfinishedResponses.put(host, response);
                return response;
            }
            lengthBuffer.rewind();
            long fileLength = lengthBuffer.getLong();
            if (fileLength < 0 || fileLength > Integer.MAX_VALUE) {
                throw new IOException("Invalid file length " + fileLength + " from " + host);
            }
            response.setBuffer(ByteBuffer.allocate((int) fileLength));
        }

        // Phase 2: read the file content.
        ByteBuffer buffer = response.getBuffer();
        if (buffer.hasRemaining()) {
            readOrFail(channel, buffer, host);
        }
        if (buffer.hasRemaining()) {
            unfinishedResponses.put(host, response);
        } else {
            // The whole response has been read.
            buffer.rewind();
            response.setFinished(true);
            unfinishedResponses.remove(host);
        }
        return response;
    }

    /**
     * Request-timeout thread. Once per {@link #REQUEST_TIMEOUT_CHECK_INTERVAL} it fails every
     * in-flight request older than {@link #REQUEST_TIMEOUT}. The age is measured from when the
     * request finished sending, or from dispatch while it is still being written.
     */
    class RequestTimeoutCheckThread extends Thread {
        @Override
        public void run() {
            while (true) {
                try {
                    failTimedOutRequests(System.currentTimeMillis());
                } catch (RuntimeException e) {
                    e.printStackTrace();
                }
                try {
                    Thread.sleep(REQUEST_TIMEOUT_CHECK_INTERVAL);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }

        /** Fails every in-flight request whose send time is more than REQUEST_TIMEOUT before {@code now}. */
        private void failTimedOutRequests(long now) {
            for (Map.Entry<Host, NetworkRequest> entry : toSendRequests.entrySet()) {
                NetworkRequest request = entry.getValue();
                if (now - request.getSendTime() > REQUEST_TIMEOUT) {
                    failRequest(entry.getKey(), request);
                }
            }
        }
    }

    /**
     * Delivers {@code response} for {@code request}, which was in flight to {@code host}.
     *
     * <p>For a request that needs a response, the response is parked for
     * {@link #waitResponse(String)}, which also frees the host's in-flight slot. Any other request
     * frees the slot immediately and has its callback, if any, invoked.
     */
    private void completeRequest(Host host, NetworkRequest request, NetworkResponse response) {
        if (request.getNeedResponse()) {
            finishedResponses.put(request.getId(), response);
            return;
        }
        // Removing the entry first claims the request, so the poll thread and the timeout thread
        // can never both run its callback.
        if (toSendRequests.remove(host, request) && request.getCallback() != null) {
            try {
                request.getCallback().process(response);
            } catch (RuntimeException e) {
                // A failing callback must not take down the connection or the calling thread.
                e.printStackTrace();
            }
        }
    }

    /**
     * Completes {@code request} with an error response (used on timeout and on connection failure).
     * Never replaces a response that is already waiting to be collected.
     */
    private void failRequest(Host host, NetworkRequest request) {
        NetworkResponse response = new NetworkResponse();
        response.setRequestId(request.getId());
        response.setHostname(host.getHostname());
        response.setIp(host.getIp());
        response.setError(true);
        response.setFinished(true);

        if (request.getNeedResponse()) {
            // The request stays in flight until waitResponse() collects it, so this may run on every
            // timeout sweep; keep the first (or a real) response.
            finishedResponses.putIfAbsent(request.getId(), response);
        } else {
            completeRequest(host, request, response);
        }
    }

    /**
     * Reads from {@code channel} into {@code buffer}.
     *
     * @throws IOException if the read fails or the peer has closed the connection (end of stream)
     */
    private static void readOrFail(SocketChannel channel, ByteBuffer buffer, Host host) throws IOException {
        if (channel.read(buffer) < 0) {
            throw new IOException("Connection closed by " + host + " before the response was complete");
        }
    }

    /** Closes {@code channel} if non-null, logging (not propagating) any failure. */
    private static void closeQuietly(SocketChannel channel) {
        if (channel == null) {
            return;
        }
        try {
            channel.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the host a selection key belongs to. This is the {@link Host} attached when the
     * connection was opened; for a key without one, it is derived from the channel's remote address.
     */
    private Host hostOf(SelectionKey key, SocketChannel channel) throws IOException {
        Object attachment = key.attachment();
        if (attachment instanceof Host) {
            return (Host) attachment;
        }
        return parseHost(channel);
    }

    /** Builds a {@link Host} from the channel's remote address (host name, ip, port). */
    private Host parseHost(SocketChannel channel) throws IOException {
        InetSocketAddress address = (InetSocketAddress) channel.getRemoteAddress();
        String hostname = address.getHostName();
        String ip = address.getAddress().getHostAddress();
        Integer nioPort = address.getPort();
        return new Host(hostname, ip, nioPort);
    }
}
